import torch

from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigProblemBase,
)
from llm_studio.src.models.text_causal_language_modeling_model import Model
from llm_studio.src.utils.modeling_utils import TokenStoppingCriteria, activate_neftune


def test_token_stopping_criteria():
    """Check TokenStoppingCriteria with single-token and multi-token stop words.

    The batch has six rows; row ``r`` holds ``r, r + 1, ..., r + 9``. The
    expectations below show that the criteria only fires once *every* row of
    the generated part (columns from ``prompt_input_ids_len`` onward) contains
    a stop word.
    """
    # Row r is [r, r + 1, ..., r + 9]: a (1, 10) ramp broadcast against a
    # (6, 1) row offset gives the same 6x10 int64 tensor as writing it out.
    batch = (torch.arange(10).unsqueeze(0) + torch.arange(6).unsqueeze(1)).long()

    def fires(criteria, ids):
        # Scores are not needed by the criteria under test.
        return criteria(input_ids=ids, scores=None)

    # Prompt length 4: the generated tokens of the final row are
    # [9, 10, 11, 12, 13, 14], which contain none of 0, 1, 2, 8 -> no stop.
    criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([0, 1, 2, 8]), prompt_input_ids_len=4
    )
    assert not fires(criteria, batch)

    # Single stop token 6 with no prompt: with 6 columns row 0 is
    # [0, 1, 2, 3, 4, 5] and lacks 6 (the other rows already contain it),
    # so no stop; with 7 columns every row contains 6.
    criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([6]), prompt_input_ids_len=0
    )
    assert not fires(criteria, batch[:, :6])
    assert fires(criteria, batch[:, :7])

    # Compound stop word [6, 7]: the full two-token sequence is present in
    # every row only once 8 columns are visible.
    criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([[6, 7]]), prompt_input_ids_len=0
    )
    for width, should_stop in ((6, False), (7, False), (8, True)):
        assert bool(fires(criteria, batch[:, :width])) is should_stop

    # A 14-token stop sequence is longer than the 10-token rows, so it can
    # never be matched.
    criteria = TokenStoppingCriteria(
        stop_word_ids=torch.tensor([list(range(14))]),
        prompt_input_ids_len=0,
    )
    assert not fires(criteria, batch)


def test_neftune_is_disabled_in_inference():
    """NEFTune noise must not change an eval-mode forward pass.

    Runs the unit-test backbone once before and once after ``activate_neftune``
    and requires matching logits. Also checks that activation stores the noise
    alpha on the input embeddings without adding any "neftune" entry to the
    model's state dict.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cfg = ConfigProblemBase(llm_backbone="h2oai/llama2-0b-unit-test")
    cfg.architecture.backbone_dtype = "float32"
    model = Model(cfg).eval().to(device)

    # One sequence of 10 random token ids in [0, 1000) with a full attention
    # mask; both tensors are created on CPU first, then moved to the device.
    token_ids = torch.randint(0, 1000, (1, 10)).to(device)
    mask = torch.ones(1, 10).to(device)

    def forward_logits():
        with torch.no_grad():
            result = model.backbone(input_ids=token_ids, attention_mask=mask)
        return result["logits"]

    logits_before = forward_logits()

    activate_neftune(model, neftune_noise_alpha=10)
    embeddings = model.backbone.get_input_embeddings()
    assert embeddings.neftune_noise_alpha == 10

    logits_after = forward_logits()
    assert torch.allclose(logits_before, logits_after)

    # Activating NEFTune must not add anything NEFTune-related to the weights
    # that would be saved with the model.
    assert not any("neftune" in name for name in model.state_dict())
